from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
# Build the Chickenpox temporal graph signal and split it into train/test parts.
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()
# train_ratio=0.2 is passed straight to temporal_signal_split; presumably the
# first 20% of the snapshots become the training split -- confirm in the docs.
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
# PyTorch Geometric Temporal
PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models
Applications
Epidemiological Forecasting
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
class RecurrentGCN(torch.nn.Module):
    """Recurrent graph network: a DCRNN layer followed by a linear read-out.

    Args:
        node_features: Passed to DCRNN as its first argument (the model below
            is built with ``node_features=4``).
    """

    def __init__(self, node_features):
        super().__init__()
        # DCRNN(node_features, 32, 1): presumably (in_channels, out_channels, K)
        # -- confirm against the torch_geometric_temporal documentation.
        self.recurrent = DCRNN(node_features, 32, 1)
        # Reduces the 32 recurrent features to a single value per row.
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        """Return the linear read-out of the ReLU-activated recurrent output.

        Because the last layer is ``Linear(32, 1)``, the result keeps a
        trailing dimension of size 1 (e.g. ``torch.Size([20, 1])``).
        """
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h


from tqdm import tqdm
# Original training recipe: one optimizer step per epoch on the mean of the
# per-snapshot losses.
model = RecurrentGCN(node_features=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for step, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # NOTE: y_hat is (N, 1) while snapshot.y is (N,), so this difference
        # broadcasts to (N, N) -- see the shapes printed after the loop.
        residual = y_hat - snapshot.y
        cost = cost + torch.mean(residual ** 2)
    cost = cost / (step + 1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

###########################################################
# I added this to check the shape.
print(y_hat.shape, snapshot.y.shape, (y_hat - snapshot.y).shape)
# Output: 100%|██████████| 200/200 [02:40<00:00, 1.24it/s]
torch.Size([20, 1]) torch.Size([20]) torch.Size([20, 20])
# Evaluate on the test split with the same (broadcasting) loss expression that
# the original training loop uses.
model.eval()
cost = 0
for step, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    residual = y_hat - snapshot.y
    cost = cost + torch.mean(residual ** 2)
# Average over the number of test snapshots, then convert to a Python float.
cost = (cost / (step + 1)).item()
print("MSE: {:.4f}".format(cost))
# >>> MSE: 1.0232
# Output: MSE: 1.0247
Shape Check (1)
# Minimal reproduction of the shapes seen above: subtracting a (20,) tensor
# from a (20, 1) tensor broadcasts to (20, 20).
a = torch.randn(20, 1)
b = torch.randn(20)
c = a - b
print(a.size(), b.size(), c.size())
# Output: torch.Size([20, 1]) torch.Size([20]) torch.Size([20, 20])
Doesn’t ‘y_hat’ have to be the same shape as snapshot.y?
- If we want to compare the model's output y_hat with the target values y, both should have the same shape so that the error is computed element-wise.
from tqdm import tqdm
# Retrain from scratch, flattening each prediction so that it has the same
# shape as snapshot.y before the squared error is taken.
model = RecurrentGCN(node_features=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for step, snapshot in enumerate(train_dataset):
        prediction = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        y_hat = prediction.reshape(-1)  # (N, 1) -> (N,)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
    # One optimizer step per epoch on the mean per-snapshot loss.
    cost = cost / (step + 1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

###########################################################
# I added this to check the shape.
print(y_hat.shape, snapshot.y.shape, (y_hat - snapshot.y).shape)
# Output: 100%|██████████| 200/200 [01:27<00:00, 2.30it/s]
torch.Size([20]) torch.Size([20]) torch.Size([20])
# Evaluate with predictions flattened to snapshot.y's shape, matching the
# reshaped training loop above, so the squared error is element-wise instead
# of being averaged over a broadcast (N, N) matrix.
model.eval()
cost = 0
with torch.no_grad():  # evaluation needs no autograd graph
    for time, snapshot in enumerate(test_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr).reshape(-1)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
# Average over the number of test snapshots, then convert to a Python float.
cost = cost / (time + 1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
# >>> MSE: 1.0232
# Recorded output (from the original evaluation code, which did not reshape
# y_hat): MSE: 1.2844
Shape Check (2)
# After flattening the (20, 1) tensor, both operands are (20,) and the
# difference stays (20,) -- no broadcasting to a matrix.
a = torch.randn(20, 1).reshape(-1)
b = torch.randn(20)
c = a - b
print(a.size(), b.size(), c.size())
# Output: torch.Size([20]) torch.Size([20]) torch.Size([20])
Web Traffic Prediction
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
# Build the WikiMaths temporal graph signal and split it into train/test parts.
loader = WikiMathsDatasetLoader()
# lags=14 matches the node_features=14 used when the model is built below.
dataset = loader.get_dataset(lags=14)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU
class RecurrentGCN(torch.nn.Module):
    """Recurrent graph network: a GConvGRU layer followed by a linear read-out.

    Args:
        node_features: Passed to GConvGRU as its first argument.
        filters: Passed to GConvGRU as its second argument and used as the
            input width of the final linear layer.
    """

    def __init__(self, node_features, filters):
        super().__init__()
        # GConvGRU(node_features, filters, 2): presumably
        # (in_channels, out_channels, K) -- confirm against the
        # torch_geometric_temporal documentation.
        self.recurrent = GConvGRU(node_features, filters, 2)
        # Reduces the `filters` recurrent features to a single value per row.
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        """Return the linear read-out of the ReLU-activated recurrent output.

        Because the last layer is ``Linear(filters, 1)``, the result keeps a
        trailing dimension of size 1 (e.g. ``torch.Size([1068, 1])``).
        """
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h


from tqdm import tqdm
# Original training recipe for WikiMaths: one optimizer step per snapshot.
model = RecurrentGCN(node_features=14, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(50)):
    for snapshot in train_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # NOTE: y_hat is (N, 1) while snapshot.y is (N,), so this difference
        # broadcasts to (N, N) -- see the shapes printed after the loop.
        residual = y_hat - snapshot.y
        cost = torch.mean(residual ** 2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

###########################################################
# I added this to check the shape.
print(y_hat.shape, snapshot.y.shape, (y_hat - snapshot.y).shape)
# Output: 100%|██████████| 50/50 [31:26<00:00, 37.73s/it]
torch.Size([1068, 1]) torch.Size([1068]) torch.Size([1068, 1068])
# Evaluate on the test split with the same (broadcasting) loss expression that
# the original training loop uses.
model.eval()
cost = 0
for step, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    residual = y_hat - snapshot.y
    cost = cost + torch.mean(residual ** 2)
# Average over the number of test snapshots, then convert to a Python float.
cost = (cost / (step + 1)).item()
print("MSE: {:.4f}".format(cost))
# >>> MSE: 0.7760
# Output: MSE: 0.7939
Shape Check (1)
# Minimal reproduction of the shapes seen above: subtracting a (1068,) tensor
# from a (1068, 1) tensor broadcasts to (1068, 1068).
a = torch.randn(1068, 1)
b = torch.randn(1068)
c = a - b
print(a.size(), b.size(), c.size())
# Output: torch.Size([1068, 1]) torch.Size([1068]) torch.Size([1068, 1068])
What if the code changes the shape of y_hat?
from tqdm import tqdm
# Retrain from scratch with each prediction flattened to snapshot.y's shape.
# The original version of this cell omitted the reshape, so it still
# broadcast the difference to (N, N), as its recorded output below shows.
model = RecurrentGCN(node_features=14, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(50)):
    for time, snapshot in enumerate(train_dataset):
        # (N, 1) -> (N,) so the squared error below is element-wise.
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr).reshape(-1)
        cost = torch.mean((y_hat - snapshot.y) ** 2)
        # One optimizer step per snapshot.
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

###########################################################
# I added this to check the shape.
print(y_hat.shape, snapshot.y.shape, (y_hat - snapshot.y).shape)
# Recorded output (from the original run, which did not reshape y_hat):
# 100%|██████████| 50/50 [36:39<00:00, 43.99s/it]
# torch.Size([1068, 1]) torch.Size([1068]) torch.Size([1068, 1068])
# With the reshape, all three printed shapes are expected to be
# torch.Size([1068]).
# Evaluate with predictions flattened to snapshot.y's shape so the squared
# error is element-wise instead of being averaged over a broadcast (N, N)
# matrix.
model.eval()
cost = 0
with torch.no_grad():  # evaluation needs no autograd graph
    for time, snapshot in enumerate(test_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr).reshape(-1)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
# Average over the number of test snapshots, then convert to a Python float.
cost = cost / (time + 1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
# >>> MSE: 0.7760
# Recorded output (from the original evaluation code, which did not reshape
# y_hat): MSE: 0.7807
Shape Check (2)
# After flattening the (1068, 1) tensor, both operands are (1068,) and the
# difference stays (1068,) -- no broadcasting to a matrix.
a = torch.randn(1068, 1).reshape(-1)
b = torch.randn(1068)
c = a - b
print(a.size(), b.size(), c.size())
# Output: torch.Size([1068]) torch.Size([1068]) torch.Size([1068])
References
@inproceedings{rozemberczki2021pytorch, author = {Benedek Rozemberczki and Paul Scherer and Yixuan He and George Panagopoulos and Alexander Riedel and Maria Astefanoaei and Oliver Kiss and Ferenc Beres and Guzman Lopez and Nicolas Collignon and Rik Sarkar}, title = {{PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models}}, year = {2021}, booktitle={Proceedings of the 30th ACM International Conference on Information and Knowledge Management}, pages = {4564–4573}, }